# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

from typing import Optional

from ..core.tensor import LocalTensor
from ..core.utils import require_jit, global_builder
from .tiling import SoftmaxTiling


@require_jit

def softmax(dst: LocalTensor, sum: LocalTensor, max: LocalTensor, src: LocalTensor, tiling: SoftmaxTiling,
            temp_buffer: Optional[LocalTensor] = None, reuse_source: bool = False, basic_block: bool = False,
            data_format_nz: bool = False) -> None:
    """
    将输入tensor[m0, m1, ...mt, n]（t大于等于0）的非尾轴长度相乘的结果看作m，则输入tensor的shape看作[m, n]。
    为方便理解，通过Python脚本实现的方式，表达其计算公式（以输入为ND格式为例）如下，其中src是源操作数（输入），dst、sum、max为目的操作数（输出）。

    .. code-block:: python

        def softmax(src):
            # 基于last轴进行rowmax（按行取最大值）处理
            max = np.max(src, axis=-1, keepdims=True)
            sub = src - max
            exp = np.exp(sub)
            # 基于last轴进行rowsum（按行求和）处理
            sum = np.sum(exp, axis=-1, keepdims=True)
            dst = exp / sum
            return dst, max, sum

    **对应的Ascend C函数原型**

    - 接口框架申请临时空间
    
      - LocalTensor的数据类型相同 

        .. code-block:: c++

            template <typename T, bool isReuseSource = false, bool isBasicBlock = false, 
            bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG>
            __aicore__ inline void SoftMax(const LocalTensor<T>& dstTensor, const LocalTensor<T>& sumTensor, 
                                          const LocalTensor<T>& maxTensor, const LocalTensor<T>& srcTensor, 
                                          const SoftMaxTiling& tiling, const SoftMaxShapeInfo& softmaxShapeInfo = {})

      - LocalTensor的数据类型不同

        .. code-block:: c++

            template <typename T, bool isReuseSource = false, bool isBasicBlock = false, 
            bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG>
            __aicore__ inline void SoftMax(const LocalTensor<half>& dstTensor, const LocalTensor<float>& sumTensor, 
                                          const LocalTensor<float>& maxTensor, const LocalTensor<half>& srcTensor, 
                                          const SoftMaxTiling& tiling, const SoftMaxShapeInfo& softmaxShapeInfo = {})

      - 不带sumTensor和maxTensor参数

        .. code-block:: c++

            template <typename T, bool isReuseSource = false, bool isBasicBlock = false, 
            const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG>
            __aicore__ inline void SoftMax(const LocalTensor<T>& dstTensor, const LocalTensor<T>& srcTensor, 
                                        const SoftMaxTiling& tiling, const SoftMaxShapeInfo& softmaxShapeInfo = {})

    - 通过sharedTmpBuffer入参传入临时空间

      - LocalTensor的数据类型相同

        .. code-block:: c++

            template <typename T, bool isReuseSource = false, bool isBasicBlock = false, 
            bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG>
            __aicore__ inline void SoftMax(const LocalTensor<T>& dstTensor, const LocalTensor<T>& sumTensor, 
                                          const LocalTensor<T>& maxTensor, const LocalTensor<T>& srcTensor, 
                                          const LocalTensor<uint8_t>& sharedTmpBuffer, const SoftMaxTiling& tiling, 
                                          const SoftMaxShapeInfo& softmaxShapeInfo = {})
    
      - LocalTensor的数据类型不同

        .. code-block:: c++

            template <typename T, bool isReuseSource = false, bool isBasicBlock = false, 
            bool isDataFormatNZ = false, const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG>
            __aicore__ inline void SoftMax(const LocalTensor<half>& dstTensor, const LocalTensor<float>& sumTensor, 
                                          const LocalTensor<float>& maxTensor, const LocalTensor<half>& srcTensor, 
                                          const LocalTensor<uint8_t>& sharedTmpBuffer, const SoftMaxTiling& tiling, 
                                          const SoftMaxShapeInfo& softmaxShapeInfo = {})

      - 不带sumTensor和maxTensor参数

        .. code-block:: c++

            template <typename T, bool isReuseSource = false, bool isBasicBlock = false, 
            const SoftmaxConfig& config = SOFTMAX_DEFAULT_CFG>
            __aicore__ inline void SoftMax(const LocalTensor<T>& dstTensor, const LocalTensor<T>& srcTensor, 
                                          const LocalTensor<uint8_t>& sharedTmpBuffer, const SoftMaxTiling& tiling, 
                                          const SoftMaxShapeInfo& softmaxShapeInfo = {})
            
    **参数说明**

    - T：操作数的数据类型。
    - isReuseSource：该参数预留，传入默认值false即可。
    - isBasicBlock：srcTensor和dstTensor的shape信息和Tiling切分策略满足基本块要求的情况下，可以使能该参数用于提升性能，默认不使能。
    - isDataFormatNZ：当前输入输出的数据格式是否为NZ格式，默认数据格式为ND，即默认取值为false。
    - config：结构体模板参数，此参数可选配，SoftmaxConfig类型。
    - dst：目的操作数。
    - sum：目的操作数。
    - max：目的操作数。
    - src：源操作数。
    - sharedTmpBuffer：临时空间。
    - tiling：SoftMax计算所需Tiling信息。
    - softmaxShapeInfo：src的shape信息。

    **约束说明**

    - src和dst的Tensor空间可以复用。
    - sumTensor和maxTensor为输出，并且last轴长度必须固定32Byte，非last轴大小需要和src以及dst保持一致。
    - sumTensor和maxTensor的数据类型需要保持一致。
    - 操作数地址对齐要求请参见通用地址对齐约束。
    - 不支持sharedTmpBuffer与源操作数和目的操作数地址重叠。
    - 当参数softmaxShapeInfo中srcM != oriSrcM 或者 srcK != oriSrcK时，
    开发者需要对GM上的原始输入(oriSrcM, oriSrcK)在M或K方向补齐数据到(srcM, srcK)，补齐的数据会参与部分运算，
    在输入输出复用的场景下，API的计算结果会覆盖srcTensor中补齐的原始数据，在输入输出不复用的场景下，
    API的计算结果会覆盖dstTensor中对应srcTensor补齐位置的数据。

    **调用示例**

    .. code-block:: python

        src_local = in_queue_src.deque(T)
        sum_temp_local = sum_queue.alloc_tensor(T)
        max_temp_local = max_queue.alloc_tensor(T)
        dst_local = out_queue_dst.alloc_tensor(T)

        src_shape = asc.SoftMaxShapeInfo(height, width, height, width);
        asc.adv.softmax(dst_local, sum_temp_local, max_temp_local, srcLocal, tiling, src_shape);

        out_queue_dst.EnQue(dstLocal)
        max_queue.free_tensor(max_temp_local)
        sum_queue.free_tensor(sum_temp_local)
        in_queue_src.free_tensor(src_local)
    """
    temp_buffer = temp_buffer.to_ir() if temp_buffer is not None else None
    global_builder.get_ir_builder().create_asc_SoftMaxOp(reuseSource=reuse_source, basicBlock=basic_block,
                                                         dataFormatNZ=data_format_nz, dst=dst.to_ir(),
                                                         sumTensor=sum.to_ir(), maxTensor=max.to_ir(), src=src.to_ir(),
                                                         sharedTmpBuffer=temp_buffer, tiling=tiling.to_ir(),
                                                         softmaxShapeInfo=None)
